Create shared Megatron calibration forward loop for prune / quantize #1501
kevalmorabia97 wants to merge 4 commits
Walkthrough
This PR migrates calibration logic from HuggingFace-based utilities to a new Megatron-Core calibration forward loop.
/ok to test df7ab63
Two WARs for modelopt <= 0.44 (fixed upstream in NVIDIA/Model-Optimizer#1501):

- `prune.py`: after `import_mcore_gpt_from_hf` returns, walk the model and copy `model.layers.{i}.input_layernorm.weight` and `model.layers.{i}.post_attention_layernorm.weight` from HF into the fused `TELayerNormColumnParallelLinear.layer_norm_weight` parameters on `attention.linear_qkv` and `mlp.linear_fc1`. Without this the fused LayerNorm weights stay at random init for GPT-family models (Qwen3, Llama, ...) since modelopt 0.44's importer only loads `fused_norm` for Nemotron-H, leaving post-prune MMLU at chance. The WAR fails soft on missing HF keys, so it is a no-op on Nemotron-H (which uses `backbone.layers.{i}.norm.weight`).
- `mmlu.py`: load `modelopt.torch.utils.plugins.megatron_generate` via `importlib.import_module` to grab the submodule rather than the function the package re-exports under the same name. The previous `from ... import megatron_generate as _mg_plugin` form raised `AttributeError: 'function' object has no attribute 'broadcast_from_last_pipeline_stage'` at import time.

Signed-off-by: Keval Morabia <28916987+kevalmorabia97@users.noreply.github.com>
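The `mmlu.py` workaround relies on a general Python import behavior, which the following minimal sketch reproduces with a throwaway package. The package name `demo_plugins` and its contents are hypothetical stand-ins, not modelopt code; they only mimic a package whose `__init__` re-exports a function named like the submodule that defines it.

```python
import importlib
import sys
import tempfile
from pathlib import Path

# Build a throwaway package whose __init__ re-exports a function that has
# the same name as the submodule defining it (all names are hypothetical).
root = Path(tempfile.mkdtemp())
pkg_dir = root / "demo_plugins"
pkg_dir.mkdir()
(pkg_dir / "__init__.py").write_text("from .generate import generate\n")
(pkg_dir / "generate.py").write_text(
    "def helper():\n"
    "    return 'helper'\n"
    "\n"
    "def generate():\n"
    "    return 'generated'\n"
)
sys.path.insert(0, str(root))

# The package attribute `generate` was rebound to the function by __init__,
# so a from-import yields the function, not the submodule.
from demo_plugins import generate as shadowed

# import_module returns the entry in sys.modules, i.e. the submodule itself.
submodule = importlib.import_module("demo_plugins.generate")

shadowed_is_function = callable(shadowed) and not hasattr(shadowed, "helper")
helper_result = submodule.helper()
```

Accessing `shadowed.helper` here would raise the same kind of `AttributeError` on a function object that the commit message describes.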
Codecov Report
Additional details and impacted files
@@ Coverage Diff @@
## main #1501 +/- ##
===========================================
- Coverage 76.79% 60.35% -16.44%
===========================================
Files 474 476 +2
Lines 51560 52602 +1042
===========================================
- Hits 39593 31750 -7843
- Misses 11967 20852 +8885
`get_dataset_dataloader` tokenizes each calibration sample individually with `padding=True, truncation=True, max_length=512`. For long-document datasets like cnn_dailymail (typical article: ~700-1000 tokens), that truncates most of each article and pads short ones, feeding the importance estimator a heavily padded, contextually-impoverished batch.

Pack samples into uniform `calib_max_sequence_length` chunks from the concatenated token stream (with `eos_token_id` as document separator) the way Megatron-Bridge's calibration loop does. This exposes the model to many more distinct contexts per `calib_size` samples and eliminates padding-token contamination of activation statistics.

Measured impact on Qwen3-8B pruned to 30L/3584/11776 (5.99B params):
  before (trunc+pad): MMLU 0.486
  after (packed): MMLU 0.544 (+5.8 pts, M-Bridge ref 0.563)

The proper upstream fix is to add a `pack` mode to `get_dataset_dataloader` in modelopt (NVIDIA/Model-Optimizer#1501); this inline change makes prune.py work today against released modelopt 0.44.0.

Signed-off-by: Keval Morabia <28916987+kevalmorabia97@users.noreply.github.com>
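The packing idea described in this commit message can be sketched in plain Python as below. This is an illustration under stated assumptions, not the code in `prune.py` or modelopt: the function name `pack_token_stream` and its parameters are hypothetical, and documents are assumed to be already tokenized into lists of token IDs.

```python
from collections.abc import Iterable


def pack_token_stream(
    tokenized_docs: Iterable[list[int]],
    eos_token_id: int,
    chunk_len: int,
    max_chunks: int,
) -> list[list[int]]:
    """Concatenate documents (EOS-separated) and cut fixed-size chunks."""
    stream: list[int] = []
    for doc in tokenized_docs:
        stream.extend(doc)
        stream.append(eos_token_id)  # EOS acts as the document separator
        if len(stream) >= chunk_len * max_chunks:
            break  # enough tokens collected for the requested chunks
    # Only full chunks are emitted; a trailing partial chunk is dropped,
    # so no padding tokens are ever needed.
    n_chunks = min(len(stream) // chunk_len, max_chunks)
    return [stream[i * chunk_len:(i + 1) * chunk_len] for i in range(n_chunks)]


docs = [[1, 2, 3], [4, 5], [6, 7, 8, 9]]
chunks = pack_token_stream(docs, eos_token_id=0, chunk_len=4, max_chunks=3)
```

Every emitted row has exactly `chunk_len` real tokens, which is what removes padding from the activation statistics; the trade-off is that a chunk may begin or end in the middle of a document.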
/ok to test 20d3c5b
20d3c5b to 8b537ce
Actionable comments posted: 2
🤖 Prompt for all review comments with AI agents
Verify each finding against current code. Fix only still-valid issues, skip the
rest with a brief reason, keep changes minimal, and validate.
Inline comments:
In `@modelopt/torch/prune/plugins/mcore_minitron.py`:
- Around line 59-60: The current enablement enables HybridModel defaults but
pattern updates and metric accounting are still gated by isinstance(...,
MambaModel), causing plain HybridModel instances to skip hybrid-pattern logic;
update the checks in the functions that perform hybrid pattern updates and
candidate metric accounting (the places currently using isinstance(obj,
MambaModel)) to detect HybridModel instead (e.g., isinstance(obj, HybridModel))
so plain HybridModel instances get the same pattern handling and metric updates,
and if any MambaModel-specific behavior is required keep a secondary
isinstance(obj, MambaModel) branch for those special cases.
In `@modelopt/torch/utils/dataset_utils.py`:
- Around line 492-526: The code can silently produce zero-length output when
token_stream is too short: after computing n_chunks from token_stream, check for
the edge cases and fail fast or warn; specifically, in the packing branch after
computing n_chunks (and before building input_ids/batch_encoded) add logic that
raises a clear exception if n_chunks == 0 (mentioning tokenizer.encode,
token_stream, total_chunks, and max_sample_length) and log/warn when n_chunks <
total_chunks to inform the caller they received fewer chunks than requested;
ensure the exception/warning uses the existing logging mechanism or raises a
ValueError so callers cannot proceed with empty tensors.
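The fail-fast behavior requested for `dataset_utils.py` could look roughly like the sketch below. The helper name `count_packed_chunks` is hypothetical and it takes the token count as a plain integer; the real code would derive it from `token_stream` and use whatever logging mechanism the module already has.

```python
import warnings


def count_packed_chunks(num_tokens: int, max_sample_length: int, total_chunks: int) -> int:
    """Return how many full chunks can be packed, refusing to return zero."""
    if max_sample_length <= 0:
        raise ValueError(f"max_sample_length must be positive, got {max_sample_length}")
    n_chunks = min(num_tokens // max_sample_length, total_chunks)
    if n_chunks == 0:
        # Fail fast instead of handing empty tensors to the caller.
        raise ValueError(
            f"token stream has {num_tokens} tokens, fewer than one chunk of "
            f"max_sample_length={max_sample_length} (requested total_chunks={total_chunks})"
        )
    if n_chunks < total_chunks:
        # Not fatal, but the caller gets fewer chunks than requested.
        warnings.warn(
            f"only {n_chunks} of {total_chunks} requested chunks could be packed",
            stacklevel=2,
        )
    return n_chunks
```

Raising a `ValueError` for the empty case and only warning for the short case matches the distinction the review comment draws between the two edge cases.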
ℹ️ Review info
⚙️ Run configuration
Configuration used: Path: .coderabbit.yaml
Review profile: CHILL
Plan: Enterprise
Run ID: d3ed2625-6222-473c-80cd-32fb4d6fbd4c
📒 Files selected for processing (9)
- CHANGELOG.rst
- modelopt/torch/export/plugins/mcore_qwen.py
- modelopt/torch/export/plugins/megatron_importer.py
- modelopt/torch/nas/plugins/megatron.py
- modelopt/torch/prune/plugins/mcore_minitron.py
- modelopt/torch/utils/dataset_utils.py
- modelopt/torch/utils/plugins/megatron_generate.py
- modelopt/torch/utils/plugins/megatron_mmlu.py
- tests/unit/torch/utils/test_dataset_utils.py
|
/claude review |
Claude review summary
Findings: CRITICAL: 1, IMPORTANT: 2, SUGGESTION: 2
Most impactful
Strengths
Risk assessment
Medium. The Qwen3 / Nemotron-H prune paths the PR explicitly validated end-to-end look solid, and the calibration-packing primitive is well-isolated behind an opt-in flag. The pre-existing
8b537ce to b70423f
d174a08 to ca3af81
c6e4e98 to 5e1b424
|
@CodeRabbit review |
|
/claude review |
✅ Actions performed: Review triggered.
|
Claude review summary
Findings: CRITICAL: 1, IMPORTANT: 1, SUGGESTION: 3.
The shared calibration loop is a clean consolidation and the experimental data backs the design. One blocker, one back-compat side effect, and a few smaller notes — all in the new megatron_calibration.py.
Highest-impact
- CRITICAL Algorithm: padding-direction assumption (`megatron_calibration.py:131-144`). The per-row trim slices `ids[b, :real_len]` and the no-padding branch overwrites `ids[:, -1]`, both of which assume right-padded sequences. But `get_dataset_dataloader` (the function this loop calls) explicitly recommends `padding_side="left"` and warns when it isn't (`dataset_utils.py:593-596`). With left-padding, the trim selects padding tokens instead of real tokens, re-introducing the exact contamination the PR is trying to eliminate. Your Qwen3-8B numbers move in the right direction because that tokenizer happens to default to right-padding, but a user who follows the dataloader's documented recommendation silently regresses. Capture `padding_side` once and branch the slice (`ids[b, -real_len:]` for left-padding); also worth a regression test row whose `attention_mask` has zeros to lock the behavior in.
- IMPORTANT Compatibility: caller-tokenizer mutation (`megatron_calibration.py:87-88`). Setting `tokenizer.pad_token = tokenizer.eos_token` mutates the caller's tokenizer object before the dataloader's deepcopy. The unwrapped tokenizer in `prune_minitron.py` is reused downstream (MMLU eval, export). Tokenizers whose `pad_token` was `None` will silently start emitting EOS as their pad token in every later code path that touches them. Local-copy the tokenizer before the mutation, or set the field on the dataloader-internal deepcopy only.
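Both findings can be illustrated with a small pure-Python sketch. Note that it selects real tokens by attention mask rather than branching on `padding_side` as the comment suggests; that is a swapped-in alternative which works for either padding direction. Lists stand in for tensors, and `trim_row`, `with_pad_token`, and `FakeTokenizer` are hypothetical names, not modelopt or HuggingFace APIs.

```python
import copy


def trim_row(ids: list[int], mask: list[int], eos_token_id: int) -> list[int]:
    """Keep positions where mask == 1, then force EOS at the last kept position."""
    real = [tok for tok, keep in zip(ids, mask) if keep]
    if real:
        real[-1] = eos_token_id  # overwrite the last real token with EOS
    return real


class FakeTokenizer:
    """Hypothetical stand-in exposing only the two attributes used here."""

    def __init__(self) -> None:
        self.pad_token = None
        self.eos_token = "</s>"


def with_pad_token(tokenizer):
    """Return a private copy with pad_token set; the caller's object is untouched."""
    local = copy.deepcopy(tokenizer)
    if local.pad_token is None:
        local.pad_token = local.eos_token
    return local


right_padded = trim_row([5, 6, 7, 0, 0], [1, 1, 1, 0, 0], eos_token_id=2)
left_padded = trim_row([0, 0, 5, 6, 7], [0, 0, 1, 1, 1], eos_token_id=2)

caller_tokenizer = FakeTokenizer()
local_tokenizer = with_pad_token(caller_tokenizer)
```

Because the selection is driven by the mask, both rows reduce to the same real tokens regardless of where the padding sits, and the caller's tokenizer keeps `pad_token = None`.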
SUGGESTIONs (non-blocking)
- Docstring claim that the loop matches `GPTSFTDataset(add_eos=True)` is slightly off: that flag appends EOS, the code overwrites the last real token. Reword or actually append when under `seq_length`.
- `bool((mask == 0).any().item())` and `mask[b].sum().item()` cause a CPU-GPU sync per batch / per row. Precompute real-lengths once on device into a CPU list at builder time to drop ~2*num_batches syncs.
- The CP-rank slice inside `_forward_loop` is dead code in the tested configs (CP=1) and would actually break under CP>1 because different CP ranks would compute different per-row `real_len` and call `megatron_prefill` (a collective) with shape-mismatched inputs. Either run the trim before the CP slice, or drop the CP-shard call until CP>1 is validated.
Backward compatibility
- Public API change: removing `get_hf_mbridge_calibration_loop` from `modelopt.torch.utils.plugins.mbridge` is a breaking change for any out-of-tree caller. Acceptable since the in-tree call sites are migrated and there's no schema/checkpoint impact, but flagging it for the maintainer's awareness.
- `examples/megatron_bridge/prune_minitron.py` renames `--calib_mbs` / `--calib_gbs` to `--calib_batch_size`. Existing user scripts/CI that invoke this example with the old flags will fail at parse time. README is updated in the same PR, so this is fine.
Risk
Low-to-moderate. The CRITICAL is a correctness contradiction with the dataloader the loop relies on; once that's reconciled and the tokenizer mutation is confined to a local copy, this is a clear quality win over the previous bespoke loops.
Replaces the bespoke calibration loops in M-LM and M-Bridge prune/quantize example scripts with a single shared util, ``modelopt.torch.utils.plugins.megatron_calibration.get_megatron_calibration_forward_loop``.

The shared loop emits one sample per row (via ``get_dataset_dataloader``), trims each row to its real length using the dataloader's attention mask, and forces EOS at the trimmed last position before forwarding via ``megatron_prefill(skip_return_logits=True)``. Matches legacy ``GPTSFTDataset(add_eos=True)`` semantics exactly.

Samples are sorted by real length descending so front batches are mostly full-length (true batched forward); back batches that contain padding fall through to per-row forward to keep padding-token activations out of the calibration hook stream.

Co-Authored-By: Claude Opus 4.7 <noreply@anthropic.com>
Signed-off-by: Keval Morabia <28916987+kevalmorabia97@users.noreply.github.com>
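The sort-then-dispatch scheme described in this commit message can be sketched as a scheduling function. This is a simplified illustration, not the code in `megatron_calibration.py`: `plan_batches` is a hypothetical name, and it assumes every row is padded to `seq_length`, so a row is "padded" exactly when its real length is below `seq_length`.

```python
def plan_batches(
    lengths: list[int], batch_size: int, seq_length: int
) -> list[tuple[str, list[int]]]:
    """Decide, per batch, between one batched forward and per-row forwards."""
    # Sort sample indices by real length, longest first (stable for ties).
    order = sorted(range(len(lengths)), key=lambda i: lengths[i], reverse=True)
    plan: list[tuple[str, list[int]]] = []
    for start in range(0, len(order), batch_size):
        batch = order[start:start + batch_size]
        has_padding = any(lengths[i] < seq_length for i in batch)
        if has_padding:
            # Forward each row alone (after trimming) so no padding
            # position reaches the calibration hooks.
            plan.extend(("per_row", [i]) for i in batch)
        else:
            plan.append(("batched", batch))
    return plan


plan = plan_batches([4, 2, 4, 3, 4, 4], batch_size=2, seq_length=4)
```

Sorting concentrates the full-length rows at the front, so the batched fast path is used as often as possible and only the tail pays the per-row cost.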
5e1b424 to d1597cf
|
/claude review Addressed all 5 Claude review comments in
Validation: re-ran M-LM prune on Qwen3-8B (nemotron-post-training-dataset-v2, seq=4096, calib=512, bs=16) post-fix → MMLU 0.584, vs 0.587 before fixes. Δ -0.3 pts, well within MMLU 5% noise floor (~±0.6 pt stderr). All fixes are quality-neutral (the boolean gather is numerically identical to |
Claude review summary
Findings: CRITICAL: 0, IMPORTANT: 2, SUGGESTION: 2
The refactor is well-motivated: the unified loop is simpler than the old M-Bridge evaluate_and_print_results pipeline, the per-row trim + EOS-at-end matches GPTSFTDataset(add_eos=True), and the sort-by-length / per-batch padding-fast-path is a clean way to get batched throughput without polluting calibration stats with padding. The MMLU numbers in the PR description back this up.
Most impactful findings
- Public API removal without deprecation: `get_hf_mbridge_calibration_loop` is dropped from `modelopt.torch.utils.plugins.mbridge.__all__` outright. Out-of-tree callers that pinned a previous example version will hit `ImportError` on upgrade with no migration hint. A one-release deprecation shim is cheap.
- `inference_batch_size` defaulting to `calib_batch_size` silently changes `prune_target_memory_mb` math. Previously `--calib_mbs` was hard-pinned to 1; now `--calib_batch_size` is unrestricted and the PR description recommends 16 for throughput. A user who follows that recommendation without also setting `--inference_batch_size` silently gets a 16× larger KV-cache term in the memory constraint, i.e., a different prune target. Either default `inference_batch_size` to 1 unconditionally, or warn loudly when this divergence is implied.
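A one-release deprecation shim of the kind suggested in the first finding might look like the sketch below. The function bodies are placeholders: the real signatures of the old and new functions are not shown in this thread, so the pass-through of `*args, **kwargs` is an assumption that would need adapting if the two signatures differ.

```python
import warnings


def get_megatron_calibration_forward_loop(*args, **kwargs):
    """Placeholder for the new shared utility (real signature not shown here)."""
    return ("forward_loop", args, kwargs)


def get_hf_mbridge_calibration_loop(*args, **kwargs):
    """Deprecated alias kept for one release to give callers a migration hint."""
    warnings.warn(
        "get_hf_mbridge_calibration_loop is deprecated; use "
        "get_megatron_calibration_forward_loop instead.",
        DeprecationWarning,
        stacklevel=2,  # attribute the warning to the caller's line
    )
    return get_megatron_calibration_forward_loop(*args, **kwargs)
```

With the alias still importable, out-of-tree callers see a `DeprecationWarning` naming the replacement instead of an `ImportError` on upgrade.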
Non-blocking
- CP > 1 is not really supported by the `get_batch_on_this_cp_rank` + `megatron_prefill` combination (`megatron_prefill` rebuilds `attention_mask` / `position_ids` from the locally-sliced length). Fine while the validated configs are CP=1, worth either an assert or removing the CP wrapper to avoid future foot-gun.
- Minor docstring nit: "no logits compute" overstates what `skip_return_logits=True` does.
Risk assessment
Algorithmic changes are validated by the MMLU sweeps in the PR description and the per-row trim / EOS / no-padding-in-calibration logic looks right. The risk surface is the user-facing surface: removed public symbol and the silent-default-shift on the prune memory target. Both are easy to address.
- Assert CP=1 in the calibration forward loop (megatron_prefill builds causal mask + position_ids over the local input tensor length, which would silently produce wrong activations under CP>1). Calibration sequences are short enough that CP doesn't help anyway.
- Drop the get_batch_on_this_cp_rank call (was a no-op under CP=1 and broken under CP>1: the per-row branch would have produced rank-local trimmed lengths, violating the collective-shape assumption).
- Docstring: "logits-free prefill" → "skips returning logits / loss compute; the LM head still runs and activation hooks still fire on every layer" (matches what skip_return_logits=True actually does).

Co-Authored-By: Claude Opus 4.7 <noreply@anthropic.com>
Signed-off-by: Keval Morabia <28916987+kevalmorabia97@users.noreply.github.com>
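The CP guard from the first bullet amounts to a simple precondition check. In the sketch below the world size is passed in as a plain integer so the example runs without Megatron; the function name `require_no_context_parallel` is hypothetical, and the real loop would obtain the value from Megatron's parallel state.

```python
def require_no_context_parallel(context_parallel_world_size: int) -> None:
    """Refuse to calibrate when sequences are sharded across CP ranks."""
    if context_parallel_world_size != 1:
        raise RuntimeError(
            "calibration forward loop requires context_parallel_world_size == 1, "
            f"got {context_parallel_world_size}"
        )


require_no_context_parallel(1)  # passes silently in the validated configuration
```

Failing loudly is preferable here because, as the commit message notes, the alternative is activations that are silently wrong.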
|
/claude review Context for this round:
|
Claude re-review
Findings (this pass): CRITICAL: 0, IMPORTANT: 2 (both carried over from prior review, still unresolved in code), SUGGESTION: 0.
Re-verification of prior findings
The single new commit (Address Claude review: CP-assert + clarify prefill docstring) addressed the two SUGGESTION items from the 2026-05-20 review:
- ✅ CP > 1 footgun: runtime `RuntimeError` now raised inside `_forward_loop` when `context_parallel_world_size != 1`.
- ✅ Docstring overstatement: module docstring now accurately describes `skip_return_logits=True` as "skips returning logits / loss compute … the LM head still runs and activation hooks still fire."
The two IMPORTANT items are not addressed:
- [IMPORTANT Compatibility] `get_hf_mbridge_calibration_loop` is removed from `modelopt/torch/utils/plugins/mbridge.py:__all__` outright with the implementation deleted. Out-of-tree callers pinned to a previous release will hit `ImportError` on upgrade with no migration hint. A one-release deprecation shim (re-export plus `DeprecationWarning` pointing to `get_megatron_calibration_forward_loop`) would be cheap insurance.
- [IMPORTANT Compatibility] `examples/megatron_bridge/prune_minitron.py:155-164` defaults `--inference_batch_size` to `--calib_batch_size`. Previously `--calib_mbs` was hard-pinned to `choices=[1]`, so `inference_batch_size` always fell back to `1`, and users could safely bump `--calib_gbs` for throughput without affecting the prune memory target. Now, the PR description recommends `--calib_batch_size 16` (and the experimental table in the description used `calib_bs 16` for the production runs), but anyone following that recommendation without also setting `--inference_batch_size 1` will silently get a 16× larger KV-cache term in `--prune_target_memory_mb`, a different prune target than they intended. Default `inference_batch_size=1` unconditionally, or warn when divergence is implied.
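One way to implement the second recommendation is sketched below with `argparse`. Only the two flag names come from this thread; the parser is a reduced, hypothetical stand-in for the one in `prune_minitron.py`, and the warning text is illustrative.

```python
import argparse
import warnings


def parse_args(argv: list[str]) -> argparse.Namespace:
    parser = argparse.ArgumentParser()
    parser.add_argument("--calib_batch_size", type=int, default=1)
    # None means "not given", so an explicit value can be told apart from the default.
    parser.add_argument("--inference_batch_size", type=int, default=None)
    args = parser.parse_args(argv)
    if args.inference_batch_size is None:
        # Decouple the memory-target math from the calibration throughput knob.
        args.inference_batch_size = 1
        if args.calib_batch_size != 1:
            warnings.warn(
                "--inference_batch_size not set; defaulting to 1 rather than "
                f"--calib_batch_size={args.calib_batch_size}",
                stacklevel=2,
            )
    return args


args = parse_args(["--calib_batch_size", "16"])
```

Using `None` as the sentinel keeps an explicit `--inference_batch_size 16` possible for users who really want the larger KV-cache term in the memory constraint.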
Algorithm / mode-state re-trace
End-to-end trace of the new get_megatron_calibration_forward_loop against current state:
- Padding direction: boolean-mask gather (`ids[b][mask[b].bool()]`) correctly handles both left- and right-padded tokenizers; the prior padding-direction concern is fully addressed.
- Tokenizer mutation: local `copy.deepcopy(tokenizer)` before `tokenizer.pad_token = tokenizer.eos_token` correctly isolates the mutation from the caller's tokenizer. `get_dataset_dataloader` does its own deepcopy, so the caller is double-protected.
- CPU-GPU sync: `lengths_cpu` precomputed once on CPU; no per-batch sync inside the forward hot loop. ✓
- Sort-by-length / per-batch padding-fast-path: for batches with no padding, batched forward; for any batch containing padding, falls through to per-row forward. Calibration statistics (amax / channel-importance) are order-invariant aggregates, so the re-ordering is bit-identical to un-sorted. ✓
- PP coordination: all PP ranks see the same dataloader output (deterministic `get_dataset_samples` + `shuffle=False`) and call `megatron_prefill` with matching shapes per iteration. ✓
- EOS-at-row-end semantics: overwrites the row's last real token with EOS for both the per-row and batched paths. Docstring is honest about the under-cap-row content-token loss trade-off. ✓
- Loop invariants: partial-tail batches (`n % batch_size != 0`) handled correctly via `for b in range(ids.shape[0])`. Zero-real-token rows skipped via `row.shape[1] < 1` `continue`.
Backward compatibility
- CLI rename `--calib_mbs` / `--calib_gbs` → `--calib_batch_size` (item 2 above).
- Public API removal of `get_hf_mbridge_calibration_loop` (item 1 above).
- No `modelopt_state` schema changes.
Risk
Low-to-moderate. The algorithm itself is sound and matches the experimental data in the PR description. Risk surface is the user-facing surface: a public-symbol removal with no shim, and a CLI default coupling that quietly changes the prune target when the user follows the PR's own throughput recommendation. Both are easy to address.
Signed-off-by: Keval Morabia <28916987+kevalmorabia97@users.noreply.github.com>
Signed-off-by: Keval Morabia <28916987+kevalmorabia97@users.noreply.github.com>
Summary
Replaces the bespoke calibration loops in Megatron-LM and Megatron-Bridge prune / quantize example scripts with a single shared utility, `modelopt.torch.utils.plugins.megatron_calibration.get_megatron_calibration_forward_loop`.

The shared loop:
- Uses `get_dataset_dataloader` (one sample per row, batch-padded) as the single source of truth for the calibration dataset surface.
- Trims each row to its real length using the dataloader's `attention_mask`, then forces EOS at the trimmed last position, matching MBridge's `GPTSFTDataset(add_eos=True)` semantics exactly. Padding-token activations would otherwise be hooked into calibration statistics regardless of attention masking (FFN/layernorm fire on every position), causing a substantial MMLU regression on prune (-5 to -7 pts in our experiments).
- Sorts samples by real length descending so front batches are mostly full-length (`has_padding=False` → true batched forward, mbs > 1 throughput); back batches that contain padding fall through to per-row forward to keep calibration stats clean.
- Forwards via `megatron_prefill(skip_return_logits=True)` (no logits compute, just activation flow for hooks).

Migrates four call sites to the shared util:
- `examples/megatron_bridge/prune_minitron.py`
- `Megatron-LM/examples/post_training/modelopt/{prune,quantize}.py` (separate PR: NVIDIA/Megatron-LM#4881)
- `Megatron-Bridge/examples/quantization/quantize.py` (separate PR)

Unified defaults across all four sites: `--calib-dataset nemotron-post-training-dataset-v2`, `--calib-size 1024`, `--calib-max-sequence-length 4096`, `--calib-batch-size 1`. Conservative defaults sized for MoE pruning (top-K routing → fewer tokens per expert → more samples × longer seq needed for stable amax / scoring).

Experimental results
Validated on Qwen3-8B (TP=1 PP=2 for prune; TP=2 PP=1 for MMLU eval) with the production shared loop vs the original per-example bespoke loops.
MMLU noise floor (binomial 2σ at acc ≈ 0.70):
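As a reference for how a binomial 2σ noise floor is computed, here is a minimal sketch. The question count passed in the example is a hypothetical round number chosen for illustration, not the evaluation-set size used in this PR.

```python
import math


def binomial_two_sigma(acc: float, n_questions: int) -> float:
    """Two standard errors of an accuracy estimated from n independent questions."""
    if not 0.0 <= acc <= 1.0:
        raise ValueError("acc must be in [0, 1]")
    if n_questions <= 0:
        raise ValueError("n_questions must be positive")
    return 2.0 * math.sqrt(acc * (1.0 - acc) / n_questions)


# Hypothetical subset of 1000 questions at acc = 0.70.
two_sigma = binomial_two_sigma(0.70, 1000)
```

Since the standard error scales as 1/sqrt(n), an evaluation on a 5% subset has roughly 4.5× the noise of the full set, which is why the subset and full-set results need separate noise floors.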
M-LM Minitron prune (Qwen3-8B → 30L/3584/11776 ≈ 6B params)
… `pack=True` WAR)
M-LM NVFP4 quantize (`NVFP4_DEFAULT_CFG`)
… `get_calib_dataloader` pad+truncate)
For reference, `hf_ptq.py` on the same nemotron-v2 / seq=4096 / calib=512 setup reaches 5% MMLU 0.707 / Full 0.712 at bs=1, confirming the M-LM and HF calibration paths agree within MMLU noise for Qwen3-8B.
M-Bridge Minitron prune
… `get_hf_mbridge_calibration_loop`, M-Bridge SFT pipeline)
Conclusions
- … `calib_size` (512 → 1024) and longer `seq` (2048 → 4096) within noise for dense Qwen3-8B but kept as conservative defaults for MoE robustness.
- … `trim+EOS` semantically matches `GPTSFTDataset(add_eos=True)`, closing a prior M-Bridge prune regression introduced by earlier `pack=True` variants; (3) MoE-friendly conservative defaults across all four call sites.
🤖 Generated with Claude Code
Summary by CodeRabbit
Release Notes
New Features
Breaking Changes
`--calib_mbs` / `--calib_gbs` replaced with `--calib_batch_size`.
Documentation